Add factorized llama model for testing. #604
Conversation
Force-pushed from 6c5679c to 3a84009
Sorry it took so long! Super busy week!
```python
_loss_fn = hax.filter_checkpoint(_per_layer_loss)

loss, _ = hax.fold(
```
I think this is probably the best way to do this. You can of course use scan if you wanted to keep per-layer losses for logging or weighting or something
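For the curious, here's a plain-JAX sketch of that trade-off (not the haliax API, and a toy per-layer loss): a fold only carries the accumulated loss, while a scan also returns the stacked per-layer values you could log or reweight.

```python
import jax
import jax.numpy as jnp

def step(total, layer_params):
    layer_loss = jnp.sum(layer_params ** 2)   # stand-in for the real per-layer loss
    return total + layer_loss, layer_loss     # (carry, per-layer output)

stacked_params = jnp.ones((12, 4))            # leading axis = layers

total_loss, layer_losses = jax.lax.scan(step, 0.0, stacked_params)
# total_loss: scalar sum over layers; layer_losses: shape (12,) for logging.
# A fold is the same thing with the second output dropped.
```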
Yeah this took a bit of back and forth to figure out but seems sensible now.
Initially I used an unrolled loop, which "worked" on a tiny model but effectively stalled in compilation for the larger model sizes.
When I switched to scan, I ran into a few issues, mostly the normal JAX/TPU stuff, which shouldn't affect most regular models:
- The `is_scanned` magic for `hax.fold` was trying to scan over the `initial` values in addition to the layers (because by default it tries to glom any NamedArray into the scan). Maybe it's worth exempting the first argument from that logic?
- It took a while to realize I had to enable gradient checkpointing (see the sketch below). Everything "works" in a CPU test, but it just OOMs the process. Trying it on a TPU yields XLA error messages about trying to allocate >1TB of memory, along with suggestions for magic flags to get it to actually tell you which tensors are large. (And of course, when you provide those flags, it doesn't change anything...) I ended up needing to reduce the size of the model to convince XLA I was close enough to fitting for it to actually bother to do the allocation analysis and cough up the errors...

I'm also obviously special-casing to the `stacked` variant of the transformer, which feels a little gross, but... I suspect it's not worth generalizing the `hof.py` modules to support this.
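To make the gradient-checkpointing point concrete, here's a minimal plain-JAX sketch (using `jax.checkpoint` directly rather than `hax.filter_checkpoint`, with a toy layer standing in for the real transformer block): wrapping the per-layer body in a checkpoint makes backprop recompute each layer's activations instead of storing all of them, which is what keeps the scanned model from OOMing.

```python
import jax
import jax.numpy as jnp

def layer_body(h, layer_params):
    # Stand-in for one transformer layer.
    return jnp.tanh(h @ layer_params)

def scan_body(h, layer_params):
    # jax.checkpoint (a.k.a. remat) drops this layer's intermediate
    # activations on the forward pass and recomputes them during backprop.
    return jax.checkpoint(layer_body)(h, layer_params), None

def loss_fn(h0, stacked_params):
    h, _ = jax.lax.scan(scan_body, h0, stacked_params)
    return jnp.sum(h)

h0 = jnp.ones((8,))
stacked_params = jnp.ones((12, 8, 8))  # (layers, d, d)
grads = jax.grad(loss_fn, argnums=1)(h0, stacked_params)
```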
d[prefix + ".down_proj.weight"] = down_proj | ||
d[prefix + ".up_proj.weight"] = up_proj | ||
|
||
d.update( |
I need to finish the branch, but this should be very simple once it's merged. You'll have to do your low-rank approximation manually, but otherwise all of this will be handled for you.
In case you're curious/want to provide feedback, the branches are:
Oh that's much nicer indeed. I like the separation of determining the saveable state and then transforming it to HF, and having `Stacked` et al. handle more of the load.
The linearization and `out_first` logic was definitely tedious to work with when I was trying to get this working. I ended up creating one of those weird models where every dimension is a prime number to figure out what was going where :). (It didn't help that I couldn't figure out the real shape that SVD was outputting...)
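For reference, a minimal sketch of the manual rank-r factorization and the SVD shapes involved (the function name and the down/up split convention here are illustrative, not Levanter's API):

```python
import jax.numpy as jnp

def low_rank_factors(w, rank):
    # w has shape (d_out, d_in); with full_matrices=False,
    # u: (d_out, k), s: (k,), vh: (k, d_in), where k = min(d_out, d_in).
    u, s, vh = jnp.linalg.svd(w, full_matrices=False)
    down_proj = vh[:rank, :]          # (rank, d_in): project the input down
    up_proj = u[:, :rank] * s[:rank]  # (d_out, rank): project back up, absorbing s
    return down_proj, up_proj

# w ≈ up_proj @ down_proj; e.g. for a prime-dimensioned w of shape (11, 7)
# and rank 3, up_proj is (11, 3) and down_proj is (3, 7).
```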
```python
return down_proj, up_proj


class FactorizedLinear(StateDictSerializationMixin, eqx.Module):
```
So one thing you could do is do all of this as a tree transformation/model surgery, similar to how we do it for LoRA. Basically you load the model as normal, then you do something like:

```python
def replace_linear(x):
    # Leave anything that isn't a Linear untouched.
    if not isinstance(x, Linear):
        return x
    up_proj, down_proj = low_rank_approximation(x)
    return LowRankLinear(up_proj, down_proj)

modified_model = jax.tree_util.tree_map(
    replace_linear, model, is_leaf=lambda x: isinstance(x, Linear)
)
```
I'm pretty sure you could delete most of this file if you did this.
This version only works if `use_scan_layers` is false. You have to do some fanciness (that we do in LoRA) for scan layers: essentially vmap `replace_linear` whenever you detect a `Stacked`. (Maybe I should make a `stacked_aware_tree_map` or something.)
Oh that's a great idea. You'd lose a little flexibility in the size of your low-rank layers, but it would be nice to have it work across all models.
I'll try that out once I have something working end-to-end. (I don't mind starting with this and throwing it away, since it's a bit easier to debug this way than also debugging any wonky transform logic I try to write...)
Oops, that got lost in my last attempt to refactor. I should have a test for that. Dropout is super slow on TPU so I tend to avoid it...
Noted. The branches in Haliax/Levanter should make most of the "bookkeeping" parts go away, and if you moved to a "tree transformation" based approach, I think you shouldn't have to modify the core Llama code at all. (I think?!?)
I'm all ears (though please do see the twin state dict branches and lmk what you think!)
We all do! I think JAX needs better docs / examples for the "not a complete beginner but not a JAX whisperer" crowd.
Yeah... still trying to figure this out. Some bits like the batch loader helpers and some of the mixed precision/sharding logic should be "lego-ified" and split out. Probably other stuff too. I think my goal should be to make copy-paste not feel bad, if that makes sense.
hrm. please send me a stack trace if you figure it out
Ah yeah, I don't love that. I'd like to lego-ify that too
hrm. I do think the timings in the compiler are pretty carefully tuned for the "standard" case and maybe you were tripping it up? I do wish they'd offered PGO for TPU like they're starting to do for GPU... though maybe that's not the problem here.
Thanks for the feedback!
I think things are working, mostly, on the training side now: I can at least run steps. Unfortunately when I try to init the models from HF, things get stuck again for some reason. I'll need to add some more logging to see if I can figure out what's triggering that.
+1, and overall I think things have been fine: if it's been slow getting things working, it's because I hit some indecipherable XLA or JAX error, not because I had to hack around Levanter. If this were just another model, I don't see any reason I couldn't reuse the default training setup easily. But I suspect there are some parts here we can make easier to use while preserving the existing workflow.
I'll let you know if I can reproduce. I couldn't understand it because all that stuff is outside of the JIT scope, but I thought, "no need for StepInfo until I can take a step..."
Yeah, there are certainly some tuning issues with the compiler if you don't take the right approach. I didn't investigate it too much in this case though: it might just have produced a giant program and stalled trying to compile it...
Force-pushed from 6d40fc2 to e5ec7ca
Can you open an issue? I agree this is better.
(Not intended for real review! But here's the current hacked-up status of trying to get a layerwise trainer going, in case it provides food for thought.) I can't for the life of me figure out how to get GitHub to acknowledge a file copy and show a diff: if there's a tip for how to do this with git, please let me know. Brain dump:
Random things:
- `state.training_key` isn't updated with the `new_key` when we take a step: https://github.com/stanford-crfm/levanter/blob/main/src/levanter/trainer.py#L495 . IIUC, we'll use the same randomness for every batch as a result (see the key-splitting sketch at the end of this comment). I could be horribly wrong here. That said, it doesn't look like we use dropout or any other runtime randomness, so it probably doesn't change anything either way!

Good things:

There are a lot of my bad decisions here:
- I disabled `scan_layers` to make it easy for me to invoke the layers individually, but from our discussion that cripples performance: I wonder if I need to be more clever here?
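For reference, a minimal sketch of the split-and-carry key pattern the `training_key` note above expects (the names here are illustrative, not Levanter's actual `TrainerState` API):

```python
import jax

def train_step(training_key, batch):
    # Split the carried key so each step sees fresh randomness, and
    # return the new key so the caller can store it back in the state.
    step_key, new_training_key = jax.random.split(training_key)
    # ... use step_key for dropout or any other runtime randomness ...
    return new_training_key
```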